import scipy.io
import numpy as np
import pandas as pd
import logging
import re
import copy
from IO.EssayPart import EssayPart

class CageViewer(EssayPart):
    r"""
    Load matlab files generated with CageViewer

    The constructor reads the .mat file, builds ``self.data`` (one row per
    sensor sample, indexed by the 'Time' column converted to a timedelta)
    and ``self.meta`` (Group, MouseID, Phase, Session, optionally extended
    or overridden by values parsed from the file name).

    >>> E = CageViewer('testdata/000000_P03_S01_A1C1.mat', parseFileName=r'(?P<Phase>[A-Z\d]+)_(?P<Session>[A-Z\d]+)_(?P<MouseID>[A-Z\d]+)\.mat')
    >>> E
    CageViewer from 2016-01-10 07:53:47 on mouse [S01/P03/A1C1] for 2 Trials
    >>> E2 = CageViewer('testdata/000000_P03_S01_A1C1.mat')
    >>> E2
    CageViewer from 2016-01-10 07:53:47 on mouse [1/3/A1C1] for 2 Trials
    """

    #Overwrite column autodetect
    resample_how = {
        'TrialCnt': EssayPart.lambda_most_common
    }

    dtypes = {
        #'TrialType': 'category',
        #'PeriodType': 'category',
        #'Lick_Diff': 'category',
        #'NosePoke_Diff': 'category',
        #'Group': 'category',
        #'Session': 'category',
        #'MouseID': 'category',
        'BlueCueLight': 'bool',
        'HouseLight': 'bool',
        'Lick': 'bool',
        'Lick_OnSet': 'int',
        'Lick_OffSet': 'int',
        'NosePoke': 'bool',
        'NosePoke_OnSet': 'int',
        'NosePoke_OffSet': 'int',
        'Shock': 'bool',
        'Valve': 'bool',
        'YellowCueLight': 'bool',
        'Sound': 'int'}

    def __init__(self, matfilepath, parseFileName=None, resetSensorOnPeriodStart=False, roundTimeIndex=False):
        r"""
        Import a CageViewer .mat file.

        matfilepath: path of the .mat file, passed to scipy.io.loadmat.
        parseFileName: optional regex with named groups, e.g.
            r'(?P<Phase>[A-Z\d]+)_(?P<Session>[A-Z\d]+)_(?P<MouseID>[A-Z\d]+)\.mat'.
            Matched groups are merged into self.meta and win over the values
            stored in the file's protocolMeta.
        resetSensorOnPeriodStart: if True, periodsExecuted is handed to
            createDiffColumn, otherwise None is passed.
        roundTimeIndex: if True, the time index is shifted to start at zero
            and roundIndex() is called (this can produce duplicate index values).

        Any exception raised during the import is logged and swallowed, so a
        failed import leaves a partially initialised object behind.
        """
        try:
            self.filename = matfilepath
            self.rawData = scipy.io.loadmat(matfilepath, squeeze_me=True, struct_as_record=False)

            # NOTE(review): _addTrialInformation unpacks column 1 of
            # periodsExecuted as the trial *type* and column 0 as the trial
            # count - confirm that column 1 is really the intended source here.
            self.nTrials = self.rawData['periodsExecuted'][-1, 1]
            self.nPoints = self.rawData['nLoopIterations']
            self.periodsExecuted = self.rawData['periodsExecuted']
            self.start = pd.to_datetime(self.rawData['startDate'])
            self.end = pd.to_datetime(self.rawData['endDate'])

            logging.debug('Experiment contains %d Presentations, %d loop Iteratios, peformed from %s until %s', self.nTrials, self.nPoints, self.start, self.end)

            #One column per sensor, named after the entries of sensorDataMeta
            sensorData = self.rawData['sensorData']
            self.data = pd.DataFrame({
                name: sensorData[:, i]
                for i, name in enumerate(self.rawData['sensorDataMeta'])
            })

            self.meta = self._readProtocolMeta(self.rawData['protocolMeta'])

            #Values parsed from the file name override the ones stored in the file
            if parseFileName:
                self.meta.update(self._parseFileName(matfilepath, parseFileName))

            logging.debug('Adding Experiment meta data %s', self.meta)
            for g, v in self.meta.items():
                self.data.insert(0, g, v)

            self.data = self._addTrialInformation(self.data, self.rawData)
            #Fuse Premature and Precue
            #(assigned back instead of a chained inplace replace, which does
            #not reach the DataFrame when pandas copy-on-write is active)
            self.data['PeriodType'] = self.data['PeriodType'].replace(
                ['N.*Premature', 'G.*Premature'], ['NG_Precue', 'Go_Precue'], regex=True)

            how = self.autoDetectResampleMethods()
            how.update(self.resample_how)
            self.resample_how = how
            periods = self.periodsExecuted if resetSensorOnPeriodStart else None
            for c in EssayPart.filterdic(self.resample_how, 'max'):
                self.createDiffColumn(c, periods)

            self.data.index = pd.to_timedelta(self.data['Time'], unit='s')
            #Round the timeindex to 100ms (can produce duplicates)
            if roundTimeIndex:
                self.data.set_index(self.data.index - self.data.index[0], inplace=True)
                self.roundIndex()

            #Calculate the duration of each sample period in seconds.
            #Dividing by a one second timedelta is independent of the
            #resolution of the index. The last sample has no successor and
            #gets the mean duration of all the others.
            durations = np.diff(self.data.index.values) / np.timedelta64(1, 's')
            self.data['Duration'] = np.append(durations, np.mean(durations))

            #Drop Time column
            self.data.drop('Time', axis=1, inplace=True)
        except Exception as e:
            #Best effort import: log (including the traceback) but do not raise
            logging.exception('Importing matlab data failed! %s', e)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return '%s from %s on mouse [%s/%s/%s] for %d Trials' % (
            type(self).__name__, self.start,
            self.meta['Session'], self.meta['Phase'], self.meta['MouseID'],
            self.nTrials)

    @staticmethod
    def _metaString(value):
        """
        Convert a protocolMeta field to a string; empty or falsy fields give ''
        """
        #Empty fields can arrive as empty arrays, whose truth value must not be tested
        if np.size(value) == 0:
            return ''
        return str(value) if value else ''

    @staticmethod
    def _readProtocolMeta(protocolMeta):
        """
        Build the meta dictionary (Group, MouseID, Phase, Session) from the
        protocolMeta struct of the matlab file
        """
        fields = (
            ('Group', 'group'),
            ('MouseID', 'animalNumber'),
            ('Phase', 'phase'),
            ('Session', 'session'),
        )
        return {key: CageViewer._metaString(getattr(protocolMeta, attr)) for key, attr in fields}

    @staticmethod
    def _addTrialInformation(table, rawData):
        """
        Add to given panda table for each sample the Trial and Period information

        Inserts the columns TrialCnt, TrialType and PeriodType at the front of
        table (in place) and fills them from the rows of
        rawData['periodsExecuted']. Samples not covered by any period keep the
        placeholder 0. Returns the same table.
        """
        trialNames = []
        periodNames = []
        for t in rawData['protocol'].TrialsDefinition:
            trialNames.append(t.name)
            periodNames.append([p.name for p in t.periods])

        #The type columns receive names (strings), so they are created with
        #object dtype; the placeholder value stays 0
        table.insert(0, 'PeriodType', pd.Series(0, index=table.index, dtype=object))
        table.insert(0, 'TrialType', pd.Series(0, index=table.index, dtype=object))
        table.insert(0, 'TrialCnt', 0)

        #Add Trial Count, Trial Type and Period Type to each sensor sample
        #NOTE(review): onSet/offSet are used as inclusive row labels of table;
        #confirm they are not 1-based sample numbers coming from matlab
        for r in rawData['periodsExecuted']:
            tCnt, tType, pType, onSet, offSet = r

            #tType and pType are 1-based positions in the protocol definition
            table.loc[onSet:offSet, 'TrialCnt'] = tCnt
            table.loc[onSet:offSet, 'TrialType'] = trialNames[tType - 1]
            table.loc[onSet:offSet, 'PeriodType'] = periodNames[tType - 1][pType - 1]

        return table

    @staticmethod
    def _parseFileName(filename, parseString):
        r"""
        Parse file name if possible. Parse string must be of similar structure to
        r'(?P<Treatment>[A-Z\d]+)_(?P<Group>[A-Z\d]+)_(?P<MouseID>[A-Z\d]+)\.mat'
        This performs a regex search with named groups and returns the named
        groups as a dictionary (empty if the pattern does not match)

        >>> match = CageViewer._parseFileName('test/inData/input/150810_P02_S01_A1C1.mat', r'(?P<Treatment>[A-Z\d]+)_(?P<Group>[A-Z\d]+)_(?P<MouseID>[A-Z\d]+)\.mat')
        >>> match['Group']
        'S01'
        >>> match['MouseID']
        'A1C1'
        """
        match = re.search(parseString, filename)
        return {} if match is None else match.groupdict()

    def generateEssayTable(self, others):
        """
        Return a copy of self.data extended by the columns of the other sources

        For every object yielded by self.IT(others) the data is grouped by the
        mapping returned from EssayPart.assignTimeIndex, aggregated with that
        object's resample_how and the columns not yet present are left-merged
        on the index. A source that fails is logged and skipped.
        """
        duration = pd.to_datetime(self.data.index[-1].value)
        table = self.data.copy()
        for o in self.IT(others):
            try:
                if self.meta['MouseID'] != o.meta['MouseID']:
                    logging.warning('Trying to mix data from different mice!')

                duration2 = pd.to_datetime(o.data.index[-1].value)
                diff = abs(duration - duration2)
                #total_seconds() also accounts for the days part of the difference
                if diff.total_seconds() > 1:
                    logging.warning('Source %s has a duration difference of %s', o, diff)

                idxMapping = EssayPart.assignTimeIndex(self.data.index, o.data.index)
                on = o.data.groupby(idxMapping).agg(o.resample_how)
                missingCol = on.columns.difference(table.columns)
                table = pd.merge(table, on[missingCol], how='left', left_index=True, right_index=True)
            except Exception as e:
                logging.error('Adding information from object %s failed! %s', o, e)

        return table

